#include "POF/Gradient_generator_fullcal_approx.h"

#include <assert.h>
#include <math.h>

#include <cmath>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <thread>
#include <vector>

using namespace std;
// Lookup table of approximate-loss values for one 16-case section.
// It is filled in fullcalc_approx_grad_test_case_ccon and indexed as
// [unsafe_pxi][unsafe_mxi][safe_pxi][safe_mxi]; each index is a count of
// cases inside the section and therefore ranges over 0..16 (17 entries).
// NOTE(review): the object is 17^4 doubles (~650 KB) and is passed by value
// to calc_approx_grad_one.
struct at {
    double approx_loss_table[17][17][17][17];
};
// Per-test-case weights used by the enumeration in calc_approx_grad_one.
// Row = internal test case index (capacity 520); the three columns are the
// weights of that case's three states:
//   [0] counted in both pxi and mxi, [1] counted in pxi only, [2] counted in neither.
struct mt {
    double multiplier_table[520][3];
};
// Same layout as mt::multiplier_table, but holding the values that
// calc_approx_grad_one substitutes for row j of the multiplier table
// (filled in fullcalc_approx_grad_test_case_ccon as the derivatives of the
// three state weights with respect to test_case_ccon_w).
struct gt {
    double grad_table[520][3];
};
// Gradient buffers defined in another translation unit.
// ccon_gradient is written here as [label][class][batch item].
extern double ccon_gradient[n_label][MAX_n_class][MAX_batch_size];
// test_case_ccon_gradient is written here as [0][class 0 or 1][internal test case].
extern double test_case_ccon_gradient[n_label][MAX_n_class][MAX_n_internal_test_case];

// Scratch buffer: calc_approx_grad_one stores one partial result per value of
// in0 (last index, 0..2); the three partials are summed into
// test_case_ccon_gradient afterwards.
extern double test_case_ccon_gradient_tmp[n_label][MAX_n_class][MAX_n_internal_test_case][3];

// Defined elsewhere. Called at the end of the two fullcalc_* functions with the
// caller-supplied output list; presumably they copy the global gradient arrays
// above into that flat list -- TODO confirm against their definitions.
extern void flatten_ccon_gradient(int n_label, int batch_size, int n_class, double ccon_gradient_list[]);
extern void flatten_test_case_ccon_gradient(int n_label, int n_internal_test_case, int n_class, double test_case_ccon_gradient_list[]);

void fullcalc_approx_grad_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double ccon_gradient_list[]) {

    for (int pm = 0; pm < n_label; pm++) {
        for (int j = 0; j < nn_data->batch_size; j++) {
            double approx_loss_class[nn_data->n_class];
            double expsum = 0;

            for (int i = 0; i < nn_data->n_class; i++) {
                int tmp = nn_data->normalclass[pm][j]; // KKY: check
                nn_data->normalclass[pm][j] = i;
                // KKY: assumption:normaltable, posterior are already set.
                approx_loss_class[i] = calc_approx_loss(nn_data, indata, user_parameter);
                // cout<<approx_loss_class[i]<<endl;
                expsum += exp(nn_data->ccon_w[pm][i][j]);
                nn_data->normalclass[pm][j] = tmp;
            }
            for (int i = 0; i < nn_data->n_class; i++) {
                ccon_gradient[pm][i][j] = 0;
                for (int k = 0; k < nn_data->n_class; k++) {
                    if (k == i)
                        ccon_gradient[pm][i][j] += approx_loss_class[k] * (exp(nn_data->ccon_w[pm][k][j]) / expsum);
                    ccon_gradient[pm][i][j] += -approx_loss_class[k] * (exp(nn_data->ccon_w[pm][k][j]) * exp(nn_data->ccon_w[pm][i][j]) / expsum / expsum);
                }
                // NO_UPDATE
            }
        }
    }

    flatten_ccon_gradient(n_label, nn_data->batch_size, nn_data->n_class, ccon_gradient_list);
}
namespace {

// Working state shared by every nesting level of the enumeration performed by
// calc_approx_grad_one.
struct ApproxGradEnumState {
    double multiplier[16][3];            // state weights of the 16 cases; row `target` holds grad_table values
    int ans[16];                         // label (0 or 1) of each case, used to pick the pxi/mxi slot
    int target;                          // section-local index j of the differentiated case
    const double (*loss)[17][17][17];    // the caller's approx_loss_table
    int pxi[2];                          // per-label count of cases currently in state 0 or 1
    int mxi[2];                          // per-label count of cases currently in state 0
    double total;                        // running sum of all leaf contributions
};

// One nesting level: iterates the three states of case K and recurses into
// case K + 1 with the running product `prefix` extended by the state weight.
// Passing the product down by value means nothing has to be undone with a
// division when a level backtracks.
template <int K>
struct ApproxGradEnumLevel {
    static void run(ApproxGradEnumState& st, double prefix) {
        // Cases 1..11 skip states whose weight is below 0.001. For the
        // differentiated case the threshold drops to 0.001 - 2 so that its
        // (possibly negative) derivative entries are never skipped.
        // NOTE(review): cases 12 and 13 are not pruned in the original code;
        // that is preserved here -- confirm whether it is intentional.
        const bool pruned = (K <= 11);
        const double threshold = 0.001 - 2 * (st.target == K);
        const int label = st.ans[K];
        for (int state = 0; state < 3; state++) {
            const double weight = st.multiplier[K][state];
            if (pruned && weight < threshold)
                continue;
            st.pxi[label] += (state < 2);
            st.mxi[label] += (state < 1);
            ApproxGradEnumLevel<K + 1>::run(st, prefix * weight);
            st.pxi[label] -= (state < 2);
            st.mxi[label] -= (state < 1);
        }
    }
};

// Innermost level: iterates the states of case 14 and folds in the three
// states of case 15 directly through the loss table lookups.
template <>
struct ApproxGradEnumLevel<14> {
    static void run(ApproxGradEnumState& st, double prefix) {
        // Counts accumulated from cases 0..13 must leave room for cases 14 and 15.
        assert(st.pxi[0] >= 0 && st.pxi[0] <= 15);
        assert(st.pxi[1] >= 0 && st.pxi[1] <= 15);
        assert(st.mxi[0] >= 0 && st.mxi[0] <= 15);
        assert(st.mxi[1] >= 0 && st.mxi[1] <= 15);

        const int label14 = st.ans[14];
        const int is1 = st.ans[15];     // added to the label-1 counts
        const int is0 = !st.ans[15];    // added to the label-0 counts
        for (int state = 0; state < 3; state++) {
            st.pxi[label14] += (state < 2);
            st.mxi[label14] += (state < 1);

            const int p0 = st.pxi[0];
            const int m0 = st.mxi[0];
            const int p1 = st.pxi[1];
            const int m1 = st.mxi[1];
            st.total += prefix * st.multiplier[14][state]
                * (st.multiplier[15][0] * st.loss[p0 + is0][m0 + is0][p1 + is1][m1 + is1]
                   + st.multiplier[15][1] * st.loss[p0 + is0][m0][p1 + is1][m1]
                   + st.multiplier[15][2] * st.loss[p0][m0][p1][m1]);

            st.pxi[label14] -= (state < 2);
            st.mxi[label14] -= (state < 1);
        }
    }
};

}  // namespace

/// Computes one partial derivative term for internal test case `sec * 16 + j`
/// and stores it in `test_case_ccon_gradient_tmp[0][0][sec * 16 + j][in0]`.
///
/// The 16 cases of section `sec` each take one of three states. Case 0 is fixed
/// to state `in0`; cases 1..14 are enumerated exhaustively and case 15 is summed
/// at the innermost level. Every combination contributes the product of the
/// state weights times the `approx_loss_table` entry selected by the resulting
/// per-label counts. The weights come from `m`, except for case `j`, whose row
/// is taken from `g` instead.
///
/// Compared with the previous hand-unrolled version: the `register` specifiers
/// (ill-formed since C++17) are gone, the running product is no longer undone
/// by division (which accumulated rounding error and produced NaN for a zero
/// factor), the result is accumulated locally and stored once, and the handler
/// for `long double*` -- a type nothing in the computation throws -- was removed.
///
/// @param a        loss table for this section.
/// @param m        state weights for all internal test cases.
/// @param g        replacement weights used for case `j`.
/// @param sec      section index; cases `sec*16 .. sec*16+15` are used.
/// @param j        section-local index (0..15) of the differentiated case.
/// @param in0      state (0..2) assigned to the first case of the section.
/// @param nn_data  source of the case labels `ans[0][...]`, expected to be 0 or 1.
///
/// NOTE(review): `a`, `m` and `g` are taken by value (`a` alone is ~650 KB);
/// the signature is left untouched because it may be declared in the header.
void calc_approx_grad_one(at a, mt m, gt g, int sec, int j, int in0, struct NN_data* nn_data)
{
    assert(j >= 0 && j < 16);
    assert(in0 >= 0 && in0 < 3);
    assert(sec >= 0 && sec * 16 + 15 < static_cast<int>(sizeof(m.multiplier_table) / sizeof(m.multiplier_table[0])));

    const int base = sec * 16;

    ApproxGradEnumState st;
    for (int k = 0; k < 16; k++) {
        for (int l = 0; l < 3; l++)
            st.multiplier[k][l] = (k == j) ? g.grad_table[base + k][l] : m.multiplier_table[base + k][l];
        st.ans[k] = nn_data->ans[0][base + k];
        assert(st.ans[k] == 0 || st.ans[k] == 1);
    }
    st.target = j;
    st.loss = a.approx_loss_table;
    st.pxi[0] = st.pxi[1] = 0;
    st.mxi[0] = st.mxi[1] = 0;
    st.total = 0;

    // Case 0 is fixed to state in0 by the caller.
    st.pxi[st.ans[0]] += (in0 < 2);
    st.mxi[st.ans[0]] += (in0 < 1);

    ApproxGradEnumLevel<1>::run(st, st.multiplier[0][in0]);

    test_case_ccon_gradient_tmp[0][0][base + j][in0] = st.total;
}
void fullcalc_approx_grad_test_case_ccon(struct NN_data* nn_data, struct User_parameter* user_parameter, struct Indata* indata, double test_case_ccon_gradient_list[]) {

    double xi = user_parameter->xi;
    int n_rank = user_parameter->n_rank;
    int n_internal_test_case = nn_data->n_internal_test_case;
    int n_class = nn_data->n_class;

    Indata* new_indata(new Indata);
    *new_indata = *indata;
    // memcpy(new_indata, indata, sizeof(*indata));

    double main_approxloss = calc_approx_loss(nn_data, indata, user_parameter);

    //일단 safe, unsafe한 intest 개수부터 세기
    
    mt m;
    gt g;
    for (int j = 0; j < n_internal_test_case; j++) {
        m.multiplier_table[j][0] = 1e-15 + exp(nn_data->test_case_ccon_w[0][0][j] - xi) / (exp(nn_data->test_case_ccon_w[0][0][j] - xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15);
        m.multiplier_table[j][1] = 1e-15 + exp(nn_data->test_case_ccon_w[0][0][j] + xi) / (exp(nn_data->test_case_ccon_w[0][0][j] + xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15) - exp(nn_data->test_case_ccon_w[0][0][j] - xi) / (exp(nn_data->test_case_ccon_w[0][0][j] - xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15);
        m.multiplier_table[j][2] = 1e-15 + 1 - exp(nn_data->test_case_ccon_w[0][0][j] + xi) / (exp(nn_data->test_case_ccon_w[0][0][j] + xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15);
        g.grad_table[j][0] = 1e-15 + exp(nn_data->test_case_ccon_w[0][0][j] - xi) / (exp(nn_data->test_case_ccon_w[0][0][j] - xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15) - pow(exp(nn_data->test_case_ccon_w[0][0][j] - xi), 2) / (pow((exp(nn_data->test_case_ccon_w[0][0][j] - xi) + exp(nn_data->test_case_ccon_w[0][1][j])), 2) + 1e-15);
        g.grad_table[j][2] = 1e-15 - exp(nn_data->test_case_ccon_w[0][0][j] + xi) / (exp(nn_data->test_case_ccon_w[0][0][j] + xi) + exp(nn_data->test_case_ccon_w[0][1][j]) + 1e-15) + pow(exp(nn_data->test_case_ccon_w[0][0][j] + xi), 2) / (pow((exp(nn_data->test_case_ccon_w[0][0][j] + xi) + exp(nn_data->test_case_ccon_w[0][1][j])), 2) + 1e-15);
        g.grad_table[j][1] = -g.grad_table[j][0] - g.grad_table[j][2];
        // cout<<multiplier_table[j][0]<<" "<<multiplier_table[j][1]<<" "<<multiplier_table[j][2]<<" "<<grad_table[j][0]<<" "<<grad_table[j][1]<<" "<<grad_table[j][2]<<endl;
        //각 intest에 대해서 (1, 1)확률, (0,1)확률, (0,0)확률, (1,1)확률 미분값, (0,1)확률 미분값, (0,0)확률 미분값 (여기서 미분은 test_case_ccon_w에 대한 미분)
    }
    //각 table에 대한 approx 결과정리
    for (int sec = 0; sec < n_internal_test_case / 16; sec++)
    {
        int safe_cnt = 0, unsafe_cnt = 0;
        for (int j = sec * 16; j < (sec + 1) * 16; j++)
        {
            safe_cnt += (nn_data->ans[0][j]);
            unsafe_cnt += !(nn_data->ans[0][j]);
        }
        
        at a = {};
        int pxi[2] = {}, mxi[2] = {}, un_cnt = 0, s_cnt = 0;
        for (int s = 0; s < nn_data->n_internal_test_case; s++)
        {
            if (s >= sec * 16 && s < (sec + 1) * 16)
                continue;
            pxi[nn_data->ans[0][s]] += (nn_data->test_case_ccon_w[0][0][s] + xi > nn_data->test_case_ccon_w[0][1][s]);
            mxi[nn_data->ans[0][s]] += (nn_data->test_case_ccon_w[0][0][s] - xi > nn_data->test_case_ccon_w[0][1][s]);
            un_cnt += 1 - nn_data->ans[0][s];
            s_cnt += nn_data->ans[0][s];
        }
        try {
            for (int safe_pxi = 0; safe_pxi <= safe_cnt; safe_pxi++)
                for (int safe_mxi = 0; safe_mxi <= safe_pxi; safe_mxi++)
                    for (int unsafe_pxi = 0; unsafe_pxi <= unsafe_cnt; unsafe_pxi++)
                        for (int unsafe_mxi = 0; unsafe_mxi <= unsafe_pxi; unsafe_mxi++)
                        {
                            new_indata->normaltable[0][0][0][0] = unsafe_mxi + mxi[0];
                            new_indata->normaltable[0][0][0][1] = unsafe_pxi + pxi[0];
                            //'binary에서 1번째가 0번째보다 xi이상 더 크다'는 '0번째가 1번째보다 xi만큼은 작지 않다'의 반대. 따라서 다음과 같이 계산
                            new_indata->normaltable[0][0][1][0] = unsafe_cnt - unsafe_pxi + un_cnt - pxi[0];
                            new_indata->normaltable[0][0][1][1] = unsafe_cnt - unsafe_mxi + un_cnt - mxi[0];
                            new_indata->normaltable[0][1][0][0] = safe_mxi + mxi[1];
                            new_indata->normaltable[0][1][0][1] = safe_pxi + pxi[1];
                            new_indata->normaltable[0][1][1][0] = safe_cnt - safe_pxi + s_cnt - pxi[1];
                            new_indata->normaltable[0][1][1][1] = safe_cnt - safe_mxi + s_cnt - mxi[1];
                            set_posterior(nn_data, user_parameter, new_indata);
                            a.approx_loss_table[unsafe_pxi][unsafe_mxi][safe_pxi][safe_mxi] = calc_approx_loss(nn_data, new_indata, user_parameter);
                            if (isnan(a.approx_loss_table[unsafe_pxi][unsafe_mxi][safe_pxi][safe_mxi]))
                                cout << safe_pxi << safe_mxi << unsafe_pxi << unsafe_mxi << a.approx_loss_table[unsafe_pxi][unsafe_mxi][safe_pxi][safe_mxi] << endl;
                            //안전한 intest중 +xi 개수, -xi 개수, 위험한 intest중 +xi 개수, -xi 개수에 따른 approx_loss 값 출력
                        }
        }
        catch (exception& e)
        {
            cout << e.what() << endl;
        }

        //각 j별로 곱해지는 계수 계산-softmax, softmax 미분
        // cout<<"multiplier_table"<<endl;
        
        //실제 계산
        std::thread threads[400];
        //cout<<"use threads"<<endl;
        for (int j = 0; j < 16; j++)
            for (int in0 = 0; in0 < 3; in0++)
                {
                    calc_approx_grad_one(a, m, g, sec, j, in0, nn_data);
                    //cout<<"4"<< endl;
                    //threads[j * 3 + in0].join();
                }
        //cout<<"use threads"<<sec<< endl;
        for (int j = 0; j < 16; j++) {
            test_case_ccon_gradient[0][0][sec*16 + j] = 0;
            for (int in0 = 0; in0 < 3; in0++) {
                
                test_case_ccon_gradient[0][0][sec*16 + j] += test_case_ccon_gradient_tmp[0][0][sec*16 + j][in0];
            }
            test_case_ccon_gradient[0][1][sec*16 + j] = -test_case_ccon_gradient[0][0][sec*16 + j];
        }
    }
    delete new_indata;

    flatten_test_case_ccon_gradient(n_label, n_internal_test_case, n_class, test_case_ccon_gradient_list);
}
